"""
Phase 3: Automation -> CA Mediation
=====================================
Tests whether automation / capital intensity mediates the demographic
channel (Z = Z_1, Z_2, Z_3) into current accounts (ca_gdp), following the
Baron & Kenny (1986) mediation framework:
  Step 1: Z -> Mediator (capital_intensity, log_labor_productivity, and
          labsh when available)
  Step 2: Z -> CA (baseline)
  Step 3: Z + Mediator -> CA (attenuation test)
  Step 4: Mediator -> CA (direct link)
Reports attenuation percentages. OECD subsample for robustness.
"""

import pandas as pd
import numpy as np
from pathlib import Path
import sys
import warnings
warnings.filterwarnings('ignore')

# Directory layout: PROJECT_DIR is the parent of the directory containing
# this script; ROOT_DIR is one level above PROJECT_DIR.
PROJECT_DIR = Path(__file__).resolve().parent.parent
ROOT_DIR = PROJECT_DIR.parent
# Prepend ROOT_DIR/multilateral/src to the import path so the shared PanelGLS
# estimator can be imported; this must run before the import on the next line.
sys.path.insert(0, str(ROOT_DIR / "multilateral" / "src"))
from model import PanelGLS

# Input directory (processed panel CSVs) and output directory for the
# markdown tables this script writes.
DATA = PROJECT_DIR / "data" / "processed"
OUT_TABLES = PROJECT_DIR / "output" / "tables"
# Created (with parents) at import time so later writes can assume it exists.
OUT_TABLES.mkdir(parents=True, exist_ok=True)

# ISO3 codes (38 entries) defining the OECD robustness subsample.
OECD = [
    'AUS', 'AUT', 'BEL', 'CAN', 'CHL', 'COL', 'CRI', 'CZE', 'DNK', 'EST',
    'FIN', 'FRA', 'DEU', 'GRC', 'HUN', 'ISL', 'IRL', 'ISR', 'ITA', 'JPN',
    'KOR', 'LVA', 'LTU', 'LUX', 'MEX', 'NLD', 'NZL', 'NOR', 'POL', 'PRT',
    'SVK', 'SVN', 'ESP', 'SWE', 'CHE', 'TUR', 'GBR', 'USA',
]

# Control regressors appended to every regression specification in main().
CONTROLS = ['fiscal_bal_gdp', 'nfa_gdp_lag', 'rgdp_growth', 'log_rel_opw', 'kaopen']


# ── Helpers ──────────────────────────────────────────────────────────

def stars(p):
    """Return the conventional significance marker for p-value *p*.

    '***' when p < 0.01, '**' when p < 0.05, '*' when p < 0.1, otherwise
    the empty string (this includes NaN, for which every comparison is False).
    """
    for cutoff, marker in ((0.01, '***'), (0.05, '**'), (0.1, '*')):
        if p < cutoff:
            return marker
    return ''


def fmt(val, se, p):
    """Format a coefficient and its standard error for a results table.

    Returns a ``(coef_text, se_text)`` pair: the coefficient to four
    decimals followed by the significance stars from ``stars(p)``, and the
    standard error to four decimals wrapped in parentheses.
    """
    coef_text = "{:.4f}{}".format(val, stars(p))
    se_text = "({:.4f})".format(se)
    return coef_text, se_text


def run_panel_gls(df, y_var, x_vars, label, feature_names=None, min_obs=50):
    """Fit PanelGLS of ``y_var`` on ``x_vars`` and return a results dict.

    Rows missing the outcome, any regressor, ``iso3`` or ``year`` are dropped
    before fitting (listwise deletion).

    Parameters
    ----------
    df : pandas.DataFrame
        Panel containing ``y_var``, ``x_vars``, ``iso3`` and ``year``.
    y_var : str
        Outcome column.
    x_vars : sequence of str
        Regressor columns, in the order passed to the estimator.
    label : str
        Model name stored under ``'model'`` and used in printed output.
    feature_names : sequence of str, optional
        Display names for the regressors; defaults to ``x_vars``. Must have
        the same length as ``x_vars``.
    min_obs : int, default 50
        Skip the regression (return None) when fewer complete rows remain.

    Returns
    -------
    dict or None
        ``model``, ``n_obs``, ``n_countries``, ``r_squared``, ``rho`` plus
        ``<name>_coef``, ``<name>_se`` and ``<name>_p`` for each regressor;
        None when the complete-case sample has fewer than ``min_obs`` rows.

    Raises
    ------
    ValueError
        If ``feature_names`` and ``x_vars`` differ in length.
    """
    # Always a list: df[tuple] would be treated as one column key.
    x_cols = list(x_vars)
    names = list(feature_names) if feature_names else x_cols
    if len(names) != len(x_cols):
        raise ValueError(
            f"{label}: feature_names has {len(names)} entries but x_vars "
            f"has {len(x_cols)}")

    cols = [y_var] + x_cols + ['iso3', 'year']
    sub = df[cols].dropna()
    if len(sub) < min_obs:
        print(f"  {label}: insufficient obs ({len(sub)}), skipping")
        return None

    gls = PanelGLS()
    gls.fit(sub[y_var].values, sub[x_cols].values,
            sub['iso3'].values, sub['year'].values)

    result = {
        'model': label,
        'n_obs': gls.n_obs,
        'n_countries': gls.n_countries,
        'r_squared': gls.r_squared,
        'rho': gls.rho,
    }

    print(f"\n  {label} (N={gls.n_obs}, R²={gls.r_squared:.4f})")
    # NOTE(review): assumes gls.beta / se / pvalues are ordered like the
    # columns of X (x_cols) -- confirm against model.PanelGLS.
    for i, name in enumerate(names):
        coef, se, p = gls.beta[i], gls.se[i], gls.pvalues[i]
        result[f'{name}_coef'] = coef
        result[f'{name}_se'] = se
        result[f'{name}_p'] = p
        print(f"    {name:30s} {coef:8.4f} ({se:.4f}) {stars(p)}")

    return result


def write_table(results, filename, title):
    """Write regression results side by side as a markdown table.

    Parameters
    ----------
    results : list of dict
        Result dicts (as returned by ``run_panel_gls``), one table column
        per dict in the given order. Each needs ``model``, ``n_obs``,
        ``r_squared``, ``n_countries`` and ``<var>_coef`` / ``<var>_se`` /
        ``<var>_p`` entries.
    filename : str
        File name inside ``OUT_TABLES``.
    title : str
        Level-1 heading placed above the table.

    Every regressor that appears in any result (in first-seen order) gets a
    coefficient row with significance stars and a standard-error row.
    Models that omit the regressor get blank cells. N, R² and country counts
    follow after a blank spacer row. Nothing is written when ``results`` is
    empty.
    """
    if not results:
        return

    suffix = '_coef'
    # Ordered de-duplication of regressor names across all models. Slicing
    # off the suffix (rather than str.replace) keeps a '_coef' that occurs
    # inside a variable name intact.
    all_vars = list(dict.fromkeys(
        key[:-len(suffix)]
        for r in results
        for key in r
        if key.endswith(suffix)
    ))

    n_cols = len(results)
    lines = [f"# {title}\n"]
    lines.append("| Variable | " + " | ".join(r['model'] for r in results) + " |")
    lines.append("|:---|" + "|".join(["---:"] * n_cols) + "|")

    for var in all_vars:
        coef_row = f"| {var} |"
        se_row = "| |"
        for r in results:
            if f'{var}_coef' in r:
                c, s = fmt(r[f'{var}_coef'], r[f'{var}_se'], r[f'{var}_p'])
                coef_row += f" {c} |"
                se_row += f" {s} |"
            else:
                coef_row += " |"
                se_row += " |"
        lines.append(coef_row)
        lines.append(se_row)

    # Blank spacer row before the fit statistics. A second delimiter row
    # (|:---|---:|) is only valid directly under the header; elsewhere it
    # renders as a data row showing literal ':---' / '---:' text.
    lines.append("|" + " |" * (n_cols + 1))
    lines.append("| N |" + "".join(f" {r['n_obs']} |" for r in results))
    lines.append("| R² |" + "".join(f" {r['r_squared']:.4f} |" for r in results))
    lines.append("| Countries |" + "".join(f" {r['n_countries']} |" for r in results))

    lines.append("\n*Panel GLS with country and year fixed effects. "
                 "Standard errors in parentheses.*")
    lines.append("*\\*p<0.1, \\*\\*p<0.05, \\*\\*\\*p<0.01*")

    path = OUT_TABLES / filename
    # Explicit UTF-8: the table contains non-ASCII text ('R²').
    path.write_text('\n'.join(lines) + '\n', encoding='utf-8')
    print(f"\n  Saved: {path}")


def compute_attenuation(baseline, mediated, var='Z_1'):
    """Return the percentage shrinkage of *var*'s coefficient under mediation.

    Parameters
    ----------
    baseline, mediated : dict
        Result dicts that may hold a ``'<var>_coef'`` entry.
    var : str
        Regressor name, ``'Z_1'`` by default.

    Returns
    -------
    float or None
        ``(baseline - mediated) / baseline * 100``. Returns None when either
        dict lacks the coefficient, or when the baseline coefficient is
        within 1e-10 of zero (the ratio would be meaningless).
    """
    key = f'{var}_coef'
    if key in baseline and key in mediated:
        b_base = baseline[key]
        # Written as "not (... < ...)" so a NaN baseline still propagates
        # to a NaN result, exactly as a direct "< 1e-10 -> None" test does.
        if not (abs(b_base) < 1e-10):
            return (b_base - mediated[key]) / b_base * 100
    return None


# ── Main ─────────────────────────────────────────────────────────────

_Z_VARS = ['Z_1', 'Z_2', 'Z_3']


def _banner(title, width=50):
    """Print a section banner: blank line, rule, title, rule."""
    print("\n" + "=" * width)
    print(title)
    print("=" * width)


def _run_mediation(df, mediator, tag, prefix=''):
    """Estimate the four mediation steps for one mediator on a common sample.

    Every step uses only the rows of ``df`` where ``ca_gdp``, all Z
    variables, ``mediator``, all CONTROLS, ``iso3`` and ``year`` are
    observed. Fitting the baseline (Step 2) and the mediated model (Step 3)
    on the same observations makes the attenuation reflect the mediator
    itself. Otherwise it would also reflect a change of sample caused by
    rows with a missing mediator dropping out of Step 3 only.

    Parameters
    ----------
    df : pandas.DataFrame
        Panel (full sample or a subsample).
    mediator : str
        Mediator column name.
    tag : str
        Short mediator label used in model names (e.g. 'K' -> 'Step 1: Z->K').
    prefix : str
        Prepended to printed headings and model names (e.g. 'OECD ').

    Returns
    -------
    dict
        ``'step1'`` .. ``'step4'``: run_panel_gls results (None for a
        skipped step), and ``'results'``: the non-None results in step
        order, ready for write_table.
    """
    needed = ['ca_gdp', mediator] + _Z_VARS + CONTROLS + ['iso3', 'year']
    common = df.dropna(subset=needed)
    print(f"  {prefix}common sample for {mediator}: {len(common)} obs, "
          f"{common['iso3'].nunique()} countries")

    # (result key, printed heading, outcome, regressors, short model label)
    specs = [
        ('step1', f"Z -> {mediator}",
         mediator, _Z_VARS + CONTROLS, f"Z->{tag}"),
        ('step2', "Z -> ca_gdp (baseline)",
         'ca_gdp', _Z_VARS + CONTROLS, "Z->CA"),
        ('step3', f"Z + {mediator} -> ca_gdp",
         'ca_gdp', _Z_VARS + [mediator] + CONTROLS, f"Z+{tag}->CA"),
        ('step4', f"{mediator} -> ca_gdp",
         'ca_gdp', [mediator] + CONTROLS, f"{tag}->CA"),
    ]

    out = {'results': []}
    for step, (key, heading, y_var, x_vars, short) in enumerate(specs, start=1):
        print(f"\n--- {prefix}Step {step}: {heading} ---")
        res = run_panel_gls(common, y_var, x_vars,
                            f"{prefix}Step {step}: {short}")
        out[key] = res
        if res:
            out['results'].append(res)
    return out


def _print_attenuation(med, header):
    """Print the Step 2 -> Step 3 attenuation of each Z coefficient.

    Prints nothing when either Step 2 or Step 3 was skipped.
    """
    baseline, mediated = med['step2'], med['step3']
    if not (baseline and mediated):
        return
    print(f"\n--- {header} ---")
    for zvar in _Z_VARS:
        atten = compute_attenuation(baseline, mediated, zvar)
        if atten is not None:
            b_base = baseline[f'{zvar}_coef']
            b_med = mediated[f'{zvar}_coef']
            print(f"  {zvar}: {b_base:.4f} -> {b_med:.4f} "
                  f"(attenuation: {atten:.1f}%)")


def main():
    """Run all mediation analyses and write markdown tables to OUT_TABLES.

    Reads ``DATA / 'automation_panel.csv'``. Runs the four mediation steps on
    the full sample for capital intensity, for log labor productivity, and
    for the PWT labor share (only when ``labsh`` has more than 100
    non-missing values). Then repeats the capital-intensity mediation on the
    OECD subsample. Writes one table per mediation plus
    ``mediation_summary.md``, which reports the Z_1 attenuation of every
    mediation where both Step 2 and Step 3 were estimated.
    """
    print("=" * 70)
    print("PHASE 3: AUTOMATION -> CA MEDIATION")
    print("=" * 70)

    df = pd.read_csv(DATA / "automation_panel.csv")
    print(f"Panel: {len(df)} obs, {df['iso3'].nunique()} countries")

    # (mediator label, sample label, mediation output) for the summary table
    summary = []

    # ── Mediation 1: capital intensity (investment/GDP) ──
    _banner("MEDIATION 1: CAPITAL INTENSITY -> CA")
    ki = _run_mediation(df, 'capital_intensity', 'K')
    write_table(ki['results'], "automation_ca_mediation.md",
                "Capital Intensity Mediation: Z -> Investment/GDP -> CA")
    _print_attenuation(ki, "Attenuation Report (Capital Intensity)")
    summary.append(('Capital intensity', 'Full', ki))

    # ── Mediation 2: labor productivity (log GDP per capita) ──
    _banner("MEDIATION 2: LABOR PRODUCTIVITY -> CA")
    lp = _run_mediation(df, 'log_labor_productivity', 'LP')
    write_table(lp['results'], "productivity_ca_mediation.md",
                "Productivity Mediation: Z -> Log GDP/capita -> CA")
    _print_attenuation(lp, "Attenuation Report (Labor Productivity)")
    summary.append(('Labor productivity', 'Full', lp))

    # ── Mediation 3: labor share (only if PWT labsh is available) ──
    if 'labsh' in df.columns and df['labsh'].dropna().shape[0] > 100:
        _banner("MEDIATION 3: LABOR SHARE -> CA")
        ls = _run_mediation(df, 'labsh', 'labsh')
        write_table(ls['results'], "labsh_ca_mediation.md",
                    "Labor Share Mediation: Z -> labsh -> CA")
        _print_attenuation(ls, "Attenuation Report (Labor Share)")
        summary.append(('Labor share', 'Full', ls))
    else:
        print("\n  Skipping labor share mediation (PWT labsh not available)")

    # ── OECD subsample: capital intensity ──
    _banner("OECD SUBSAMPLE: CAPITAL INTENSITY MEDIATION")
    df_oecd = df[df['iso3'].isin(OECD)]
    print(f"  OECD: {len(df_oecd)} obs, {df_oecd['iso3'].nunique()} countries")
    ki_oecd = _run_mediation(df_oecd, 'capital_intensity', 'K', prefix='OECD ')
    write_table(ki_oecd['results'], "automation_ca_mediation_oecd.md",
                "OECD Capital Intensity Mediation: Z -> Investment/GDP -> CA")
    _print_attenuation(ki_oecd, "OECD Attenuation Report")
    summary.append(('Capital intensity', 'OECD', ki_oecd))

    # ── Summary: Z_1 attenuation per mediation ──
    _banner("SUMMARY: ATTENUATION TABLE")
    atten_lines = ["# Mediation Summary: Attenuation of Demographic Channel\n"]
    atten_lines.append("| Mediator | Sample | Z_1 baseline | Z_1 mediated | Attenuation % |")
    atten_lines.append("|:---|:---|---:|---:|---:|")

    for mediator_label, sample_label, med in summary:
        baseline, mediated = med['step2'], med['step3']
        if not (baseline and mediated):
            continue
        a = compute_attenuation(baseline, mediated, 'Z_1')
        if a is None:
            continue
        atten_lines.append(
            f"| {mediator_label} | {sample_label} | "
            f"{baseline['Z_1_coef']:.4f} | {mediated['Z_1_coef']:.4f} | {a:.1f}% |")

    atten_lines.append("\n*Attenuation = (baseline - mediated) / baseline x 100. "
                       "Positive values indicate the mediator absorbs part of the Z effect.*")
    atten_lines.append("*Within each mediation, all steps are estimated on the same "
                       "complete-case sample (CA, Z, mediator and controls observed).*")
    atten_lines.append("*Baron & Kenny (1986) mediation framework with PanelGLS.*")

    path = OUT_TABLES / "mediation_summary.md"
    path.write_text('\n'.join(atten_lines) + '\n', encoding='utf-8')
    print(f"\n  Saved: {path}")

    print("\n" + "=" * 70)
    print("PHASE 3 COMPLETE")
    print("=" * 70)


if __name__ == '__main__':
    main()
